{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 📚 Adding a Custom Dataset Tutorial\n",
    "\n",
    "## 🎯 Tutorial Overview\n",
    "\n",
    "This comprehensive guide walks you through the process of integrating your custom dataset into our library. The process is divided into three main steps:\n",
    "\n",
    "1. **Dataset Creation** 🔨\n",
    "   - Implement data loading mechanisms\n",
    "   - Define preprocessing steps\n",
    "   - Structure data in the required format\n",
    "\n",
    "2. **Integrate with Dataset APIs** 🔄\n",
    "   - Add dataset to the library framework\n",
    "   - Ensure compatibility with existing systems\n",
    "   - Set up proper inheritance structure\n",
    "\n",
    "3. **Configuration Setup** ⚙️\n",
    "   - Define dataset parameters\n",
    "   - Specify data paths and formats\n",
    "   - Configure preprocessing options\n",
    "\n",
    "## 📋 Tutorial Structure\n",
    "\n",
    "This tutorial follows a unique structure to provide the clearest possible learning experience:\n",
    "\n",
    "> 💡 **Main Notebook (Current File)**\n",
    "> - High-level concepts and explanations\n",
    "> - Step-by-step workflow description\n",
    "> - References to implementation files\n",
    "\n",
    "> 📁 **Supporting Files**\n",
    "> - Detailed code implementations\n",
    "> - Specific examples and use cases\n",
    "> - Technical documentation\n",
    "\n",
    "### 🛠️ Technical Framework\n",
    "\n",
    "This tutorial demonstrates custom dataset integration using:\n",
    "- `torch_geometric.data.InMemoryDataset` as the base class\n",
    "- <TB_name> library's dataset management system\n",
    "\n",
    "### 🎓 Important Notes\n",
    "\n",
    "- To make the learning process concrete, we'll work with a practical toy \"language\" dataset example:\n",
    "- While we use the \"language\" dataset as an example, all file references use the generic `<dataset_name>` format for better generalization\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Step 1: Create a Dataset 🛠️\n",
    "\n",
    "## Overview\n",
    "\n",
    "Adding your custom dataset to <TB_name> requires implementing specific loading and preprocessing functionality. We utilize the `torch_geometric.data.InMemoryDataset` interface to make this process straightforward.\n",
    "\n",
    "## Required Methods\n",
    "\n",
    "To implement your dataset, you need to override two key methods from the `torch_geometric.data.InMemoryDataset` class:\n",
    "\n",
    "- `download()`: Handles dataset acquisition\n",
    "- `process()`: Manages data preprocessing\n",
    "\n",
    "> 💡 **Reference Implementation**: For a complete example, check `topobench/data/datasets/language_dataset.py`"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Deep Dive: The Download Method"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `download()` method is responsible for acquiring dataset files from external resources. Let's examine its implementation using our language dataset example, where we store data in a GoogleDrive-hosted zip file.\n",
    "\n",
    "#### Implementation Steps\n",
    "\n",
    "1. **Download Data** 📥\n",
    "  - Fetch data from the specified source URL\n",
    "  - Save to the raw directory\n",
    "\n",
    "2. **Extract Content** 📦\n",
    "  - Unzip the downloaded file\n",
    "  - Place contents in appropriate directory\n",
    "\n",
    "3. **Organize Files** 📂\n",
    "  - Move extracted files to named folders\n",
    "  - Clean up temporary files and directories\n",
    "\n",
    "#### Code Implementation\n",
    "\n",
    "```python\n",
    "def download(self) -> None:\n",
    "    r\"\"\"Download the dataset from a URL and saves it to the raw directory.\n",
    "\n",
    "    Raises:\n",
    "        FileNotFoundError: If the dataset URL is not found.\n",
    "    \"\"\"\n",
    "    # Step 1: Download data from the source\n",
    "    self.url = self.URLS[self.name]\n",
    "    self.file_format = self.FILE_FORMAT[self.name]\n",
    "    download_file_from_drive(\n",
    "        file_link=self.url,\n",
    "        path_to_save=self.raw_dir,\n",
    "        dataset_name=self.name,\n",
    "        file_format=self.file_format,\n",
    "    )\n",
    "    \n",
    "    # Step 2: extract zip file\n",
    "    folder = self.raw_dir\n",
    "    filename = f\"{self.name}.{self.file_format}\"\n",
    "    path = osp.join(folder, filename)\n",
    "    extract_zip(path, folder)\n",
    "    # Delete zip file\n",
    "    os.unlink(path)\n",
    "    \n",
    "    # Step 3: organize files\n",
    "    # Move files from osp.join(folder, name_download) to folder\n",
    "    for file in os.listdir(osp.join(folder, self.name)):\n",
    "        shutil.move(osp.join(folder, self.name, file), folder)\n",
    "    # Delete osp.join(folder, self.name) dir\n",
    "    shutil.rmtree(osp.join(folder, self.name))\n",
    "\n",
    "\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Deep Dive: The Process Method"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `process()` method handles data preprocessing and organization. Here's the method's structure:\n",
    "\n",
    "```python\n",
    "def process(self) -> None:\n",
    "   r\"\"\"Handle the data for the dataset.\n",
    "   \n",
    "   This method loads the Language dataset, applies preprocessing \n",
    "   transformations, and saves processed data.\"\"\"\n",
    "\n",
    "   # Step 1: extract the data\n",
    "   ...  # Convert raw data to list of torch_geometric.data.Data objects\n",
    "\n",
    "   # Step 2: collate the graphs\n",
    "   self.data, self.slices = self.collate(graph_sentences)\n",
    "\n",
    "   # Step 3: save processed data\n",
    "   fs.torch_save(\n",
    "       (self._data.to_dict(), self.slices, {}, self._data.__class__),\n",
    "       self.processed_paths[0],\n",
    "   )\n",
    "\n",
    "\n",
    "```self.collate``` -- Collates a list of Data or HeteroData objects to the internal storage format; meaning that it transforms a list of torch.data.Data objectis into one torch.data.BaseData.\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Step 2: Integrate with Dataset APIs 🔄\n",
    "\n",
    "Now that we have created a dataset class, we need to integrate it with the library. In this section we describe where to add the dataset files and how to make it available through data loaders."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Here's how to structure your files, the files highlighted with ** are going to be updated: \n",
    "```yaml\n",
    "topobench/\n",
    "├── data/\n",
    "│   ├── datasets/\n",
    "│   │   ├── **init.py**\n",
    "│   │   ├── base.py\n",
    "│   │   ├── <dataset_name>.py   # Your dataset file\n",
    "│   │   └── ...\n",
    "│   ├── loaders/\n",
    "│   │   ├── init.py\n",
    "│   │   ├── base.py\n",
    "│   │   ├── graph/\n",
    "│   │   │   ├── <loader_name>.py   # Your loader file\n",
    "│   │   ├── hypergraph/\n",
    "│   │   │   ├── <loader_name>.py   # Your loader file\n",
    "│   │   ├── .../\n",
    "```\n",
    "\n",
    "To make your dataset available to library:\n",
    "\n",
    "The file ```<dataset_name>.py```  has been created during the previous steps (`us_county_demos_dataset.py` in our case) and should be placed in the `topobench/data/datasets/` directory. \n",
    "\n",
    "\n",
    "The registry `topobench/data/datasets/__init__.py` discovers the files in `topobench/data/datasets` and updates `__all__` variable of `topobench/data/datasets/__init__.py` automatically. Hence there is no need to update the `__init__.py` file manually to allow your dataset to be loaded by the library. Simply creare a file `<dataset_name>.py` and place it in the  `topobench/data/datasets/` directory.\n",
    "\n",
    "------------------------------------------------------------------------------------------------\n",
    "\n",
    "Next it is required to update the data loader system. Modify the loader file (`topobench/data/loaders/loaders.py`:) to include your custom dataset:\n",
    "\n",
    "For the the example dataset we add the following into the file ```topobench/data/loaders/graph/us_county_demos_dataset_loader.py``` which consist of the following:\n",
    "\n",
    "```python\n",
    "class USCountyDemosDatasetLoader(AbstractLoader):\n",
    "    \"\"\"Load US County Demos dataset with configurable year and task variable.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    parameters : DictConfig\n",
    "        Configuration parameters containing:\n",
    "            - data_dir: Root directory for data\n",
    "            - data_name: Name of the dataset\n",
    "            - year: Year of the dataset (if applicable)\n",
    "            - task_variable: Task variable for the dataset\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, parameters: DictConfig) -> None:\n",
    "        super().__init__(parameters)\n",
    "\n",
    "    def load_dataset(self) -> USCountyDemosDataset:\n",
    "        \"\"\"Load the US County Demos dataset.\n",
    "\n",
    "        Returns\n",
    "        -------\n",
    "        USCountyDemosDataset\n",
    "            The loaded US County Demos dataset with the appropriate `data_dir`.\n",
    "\n",
    "        Raises\n",
    "        ------\n",
    "        RuntimeError\n",
    "            If dataset loading fails.\n",
    "        \"\"\"\n",
    "\n",
    "        dataset = self._initialize_dataset()\n",
    "        self.data_dir = self._redefine_data_dir(dataset)\n",
    "        return dataset\n",
    "\n",
    "    def _initialize_dataset(self) -> USCountyDemosDataset:\n",
    "        \"\"\"Initialize the US County Demos dataset.\n",
    "\n",
    "        Returns\n",
    "        -------\n",
    "        USCountyDemosDataset\n",
    "            The initialized dataset instance.\n",
    "        \"\"\"\n",
    "        return USCountyDemosDataset(\n",
    "            root=str(self.root_data_dir),\n",
    "            name=self.parameters.data_name,\n",
    "            parameters=self.parameters,\n",
    "        )\n",
    "\n",
    "    def _redefine_data_dir(self, dataset: USCountyDemosDataset) -> Path:\n",
    "        \"\"\"Redefine the data directory based on the chosen (year, task_variable) pair.\n",
    "\n",
    "        Parameters\n",
    "        ----------\n",
    "        dataset : USCountyDemosDataset\n",
    "            The dataset instance.\n",
    "\n",
    "        Returns\n",
    "        -------\n",
    "        Path\n",
    "            The redefined data directory path.\n",
    "        \"\"\"\n",
    "        return dataset.processed_root\n",
    "```\n",
    "where the method ```load_dataset``` is required while other methods are optional used for convenience and structure.\n",
    "\n",
    "### Notes:\n",
    "- The  ```load_dataset``` of ```AbstractLoader``` class requires to return ```torch.utils.data.Dataset``` object. \n",
    "- **Important:** to allow the automatic registering of the loader, make sure to include \"DatasetLoader\" into name of loader class (Example: USCountyDemos**DatasetLoader**)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Step 3: Define Configuration 🔧"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now that we've integrated our dataset, we need to define its configuration parameters. In this section, we'll explain how to create and structure the configuration file for your dataset.\n",
    "\n",
    "## Configuration File Structure\n",
    "Create a new YAML file for your dataset in `configs/dataset/<dataset_name>.yaml` with the following structure:\n",
    "\n",
    "\n",
    "### While creating a configuration file, you will need to specify: \n",
    "\n",
    "1) Loader class (`topobench.data.loaders.USCountyDemosDatasetLoader`) for automatic instantialization inside the provided pipeline and the parameters for the loader.\n",
    "```yaml\n",
    "# Dataset loader config\n",
    "loader:\n",
    "  _target_: topobench.data.loaders.USCountyDemosDatasetLoader\n",
    "  parameters: \n",
    "    data_domain: graph             # Primary data domain. Options: ['graph', 'hypergrpah', 'cell, 'simplicial']\n",
    "    data_type: cornel              # Data type. String emphasizing from where dataset come from. \n",
    "    data_name: US-county-demos     # Name of the dataset\n",
    "    year: 2012                     # In the case of US-county-demos there are multiple version of this dataset. Options:[2012, 2016]\n",
    "    task_variable: 'Election'      # Different target variable used as target. Options: ['Election', 'MedianIncome', 'MigraRate', 'BirthRate', 'DeathRate', 'BachelorRate', 'UnemploymentRate']\n",
    "    data_dir: ${paths.data_dir}/${dataset.loader.parameters.data_domain}/${dataset.loader.parameters.data_type}\n",
    "``` \n",
    "\n",
    "2) The dataset parameters: \n",
    "\n",
    "```yaml\n",
    "# Dataset parameters\n",
    "parameters:\n",
    "  num_features: 6         # Number of features in the dataset\n",
    "  num_classes: 1          # Dimentuin of the target variable\n",
    "  task: regression        # Dataset task. Options: [classification, regression]\n",
    "  loss_type: mse          # Task-specific loss function\n",
    "  monitor_metric: mae     # Metric to monitor during training\n",
    "  task_level: node        # Task level. Options: [classification, regression]\n",
    "```\n",
    "\n",
    "3) The dataset split parameters: \n",
    "```yaml\n",
    "#splits\n",
    "split_params:\n",
    "  learning_setting: transductive      # Type of learning. Options:['transductive', 'inductive']\n",
    "  data_seed: 0                        # Seed for data splitting\n",
    "  split_type: random                  # Type of splitting. Options: ['k-fold', 'random']\n",
    "  k: 10                               # Number of folds in case of \"k-fold\" cross-validation\n",
    "  train_prop: 0.5                     # Training proportion in case of 'random' splitting strategy\n",
    "  standardize: True                   # Standardize the data or not. Options: [True, False]\n",
    "  data_split_dir: ${paths.data_dir}/data_splits/${dataset.loader.parameters.data_name}\n",
    "```\n",
    "\n",
    "4) Finally the dataloader parameters:\n",
    "\n",
    "```yaml\n",
    "# Dataloader parameters\n",
    "dataloader_params:\n",
    "  batch_size: 1       # Number of graphs per batch. In sace of transductive always 1 as there is only one graph. \n",
    "  num_workers: 0      # Number of workers for data loading\n",
    "  pin_memory: False   # Pin memory for data loading\n",
    "```\n",
    "\n",
    "### Notes:\n",
    "- The `paths` section in the configuration file is automatically populated with the paths to the data directory and the data splits directory.\n",
    "- Some of the dataset parameters are used to configure the model.yaml and other files. Hence we suggest always include the above parameters in the dataset configuration file.\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Here's the markdown for easy copying:\n",
    "\n",
    "\n",
    "## Preparing to Load the Custom Dataset: Understanding Configuration Imports\n",
    "\n",
    "Before loading our dataset, it's crucial to understand the configuration imports, particularly those from the `topobench.utils.config_resolvers` module. These utility functions play a key role in dynamically configuring your machine learning pipeline.\n",
    "\n",
    "### Key Imports for Dynamic Configuration\n",
    "\n",
    "Let's import the essential configuration resolver functions:\n",
    "\n",
    "```python\n",
    "from topobench.utils.config_resolvers import (\n",
    "    get_default_transform,\n",
    "    get_monitor_metric,\n",
    "    get_monitor_mode,\n",
    "    infer_in_channels,\n",
    ")\n",
    "```\n",
    "\n",
    "### Why These Imports Matter\n",
    "\n",
    "In our previous step, we explored configuration variables that use dynamic lookups, such as:\n",
    "\n",
    "```yaml\n",
    "data_dir: ${paths.data_dir}/${dataset.loader.parameters.data_domain}/${dataset.loader.parameters.data_type}\n",
    "```\n",
    "\n",
    "However, some configurations require more advanced automation, which is where these imported functions become invaluable.\n",
    "\n",
    "### Practical Example: Dynamic Transforms\n",
    "\n",
    "Consider the configuration in `projects/TopoBench/configs/run.yaml`, where the `transforms` parameter uses the `get_default_transform` function:\n",
    "\n",
    "```yaml\n",
    "transforms: ${get_default_transform:${dataset},${model}}\n",
    "```\n",
    "\n",
    "This syntax allows for automatic transformation selection based on the dataset and model, demonstrating the power of these configuration resolver functions.\n",
    "\n",
    "By importing and utilizing these functions, you gain:\n",
    "- Flexible configuration management\n",
    "- Automatic parameter inference\n",
    "- Reduced manual configuration overhead\n",
    "\n",
    "These facilitate seamless dataset loading and preprocessing for multiple topological domains and provide an easy and intuitive interface for incorporating novel functionality.\n",
    "```\n",
    "\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmp/ipykernel_1170891/1713955081.py:14: UserWarning: \n",
      "The version_base parameter is not specified.\n",
      "Please specify a compatability version level, or None.\n",
      "Will assume defaults for version 1.1\n",
      "  initialize(config_path=\"../configs\", job_name=\"job\")\n"
     ]
    }
   ],
   "source": [
    "from hydra import compose, initialize\n",
    "from hydra.utils import instantiate\n",
    "\n",
    "\n",
    "\n",
    "from topobench.utils.config_resolvers import (\n",
    "    get_default_transform,\n",
    "    get_monitor_metric,\n",
    "    get_monitor_mode,\n",
    "    infer_in_channels,\n",
    ")\n",
    "\n",
    "\n",
    "initialize(config_path=\"../configs\", job_name=\"job\")\n",
    "cfg = compose(\n",
    "    config_name=\"run.yaml\",\n",
    "    overrides=[\n",
    "        \"model=hypergraph/unignn2\",\n",
    "        \"dataset=graph/US-county-demos\",\n",
    "    ], \n",
    "    return_hydra_config=True\n",
    ")\n",
    "loader = instantiate(cfg.dataset.loader)\n",
    "\n",
    "\n",
    "dataset, dataset_dir = loader.load()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "US-county-demos(self.root=/home/lev/projects/TopoBench/datasets/graph/cornel, self.name=US-county-demos, self.parameters={'data_domain': 'graph', 'data_type': 'cornel', 'data_name': 'US-county-demos', 'year': 2012, 'task_variable': 'Election', 'data_dir': '/home/lev/projects/TopoBench/datasets/graph/cornel'}, self.force_reload=False)\n"
     ]
    }
   ],
   "source": [
    "print(dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Data(x=[3224, 6], edge_index=[2, 18966], y=[3224])\n"
     ]
    }
   ],
   "source": [
    "print(dataset[0])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Step 4.1: Default Data Transformations ⚙️"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "While most datasets can be used directly after integration, some require specific preprocessing transformations. These transformations might vary depending on the task, model, or other conditions.\n",
    "\n",
    "## Example Case: US-county-demos Dataset\n",
    "\n",
    "Let's look at our language dataset's structure the `compose` function. \n",
    "```python\n",
    "cfg = compose(\n",
    "    config_name=\"run.yaml\",\n",
    "    overrides=[\n",
    "        \"model=hypergraph/unignn2\",\n",
    "        \"dataset=graph/US-county-demos\",\n",
    "    ], \n",
    "    return_hydra_config=True\n",
    ")\n",
    "```\n",
    "we can see that the model is `hypergraph/unignn2` from hypergraph domain while the dataset is from graph domain.\n",
    "This implied that the discussed above `get_default_transform` function:\n",
    "\n",
    "```yaml\n",
    "transforms: ${get_default_transform:${dataset},${model}}\n",
    "```\n",
    "Inferred a default transform from graph to hypegraph domain."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Transform name: dict_keys(['graph2hypergraph_lifting'])\n",
      "Transform parameters: {'_target_': 'topobench.transforms.data_transform.DataTransform', 'transform_type': 'lifting', 'transform_name': 'HypergraphKHopLifting', 'k_value': 1, 'feature_lifting': 'ProjectionSum', 'neighborhoods': '${oc.select:model.backbone.neighborhoods,null}'}\n"
     ]
    }
   ],
   "source": [
    "print('Transform name:', cfg.transforms.keys())\n",
    "print('Transform parameters:', cfg.transforms['graph2hypergraph_lifting'])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Some datasets require might require default transforms which are applied whenever it is nedded to model the data. \n",
    "\n",
    "The topobench library provides a simple way to define custom transformations and apply them to the dataset.\n",
    "Take a look at `TopoBench/configs/transforms/dataset_defaults` folder where you can find some default transformations for different datasets.\n",
    "\n",
    "For example, REDDIT-BINARY does not have initial node features and it is a common practice to define initial features as gaussian noise.\n",
    "Hence the `TopoBench/configs/transforms/dataset_defaults/REDDIT-BINARY.yaml` file incorporates the `gaussian_noise` transform by default. \n",
    "Hence whenver you choose to uplodad the REDDIT-BINARY dataset (and do not modify ```transforms``` parameter), the `gaussian_noise` transform will be applied to the dataset.\n",
    "\n",
    "```yaml\n",
    "defaults:\n",
    "  - data_manipulations: equal_gaus_features\n",
    "  - liftings@_here_: ${get_required_lifting:graph,${model}}\n",
    "```\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "Below we provide an quick tutorial on how to create a data transformations and create a sequence of default transformations that will be executed whener you use the defined dataset config file.\n",
    "\n",
    "\n",
    "\n",
    "Below we provide an quick tutorial on how to create a data transformations and create a sequence of default transformations that will be executed whener you use the defined dataset config file."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Avoid override transforms\n",
    "cfg = compose(\n",
    "    config_name=\"run.yaml\",\n",
    "    overrides=[\n",
    "        \"model=hypergraph/unignn2\",\n",
    "        \"dataset=graph/REDDIT-BINARY\",\n",
    "    ], \n",
    "    return_hydra_config=True\n",
    ")\n",
    "loader = instantiate(cfg.dataset.loader)\n",
    "\n",
    "\n",
    "dataset, dataset_dir = loader.load()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "REDDIT_BINARY dataset does not have any initial node features"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Data(edge_index=[2, 480], y=[1], num_nodes=218)"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dataset[0]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Take a look at the default transforms and the parameters of `equal_gaus_features` transform"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Transform name: dict_keys(['equal_gaus_features', 'graph2hypergraph_lifting'])\n",
      "Transform parameters: {'_target_': 'topobench.transforms.data_transform.DataTransform', 'transform_name': 'EqualGausFeatures', 'transform_type': 'data manipulation', 'mean': 0, 'std': 0.1, 'num_features': '${dataset.parameters.num_features}'}\n"
     ]
    }
   ],
   "source": [
    "print('Transform name:', cfg.transforms.keys())\n",
    "print('Transform parameters:', cfg.transforms['equal_gaus_features'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing...\n",
      "Done!\n"
     ]
    }
   ],
   "source": [
    "from topobench.data.preprocessor import PreProcessor\n",
    "preprocessed_dataset = PreProcessor(dataset, dataset_dir, cfg['transforms'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Data(x=[218, 10], edge_index=[2, 480], y=[1], incidence_hyperedges=[218, 218], num_hyperedges=[1], x_0=[218, 10], x_hyperedges=[218, 10], num_nodes=218)"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "preprocessed_dataset[0]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The preprocessed dataset has the features generated by the preprocessor. And the connectivity of the dataset has been transformed into hypegraph domain. "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Creating your own default transforms\n",
    "\n",
    "Now when we have seen how to add custom dataset and how does the default transform works. One might want to reate your own default transforms for new dataset that will be executed always whenwever the dataset under default configuration is used.\n",
    "\n",
    "\n",
    "**To configure** the deafult transform navigate to `configs/transforms/dataset_defaults` create `<def_transforms.yaml>` and the follwoing `.yaml` file: \n",
    "\n",
    "```yaml\n",
    "defaults:\n",
    "  - transform_1: transform_1\n",
    "  - transform_2: transform_2\n",
    "  - transform_3: transform_3\n",
    "```\n",
    "\n",
    "\n",
    "**Important**\n",
    "There are different types of transforms, including `data_manipulation`, `liftings`, and `feature_liftings`. In case you want to use multiple transforms from the same categoty, let's say from `data_manipulation`, then it is required to stick to a special syntaxis. [See hydra configuration for more information]() or the example below: \n",
    "\n",
    "```yaml\n",
    "defaults:\n",
    "  - data_manipulation@first_usage: transform_1\n",
    "  - data_manipulation@second_usage: transform_2\n",
    "```\n",
    "\n",
    "\n",
    "### Notes: \n",
    "\n",
    "- **Transforms from the same category:** If There are a two transforms from the same catgory, for example, `data_manipulations`, it is required to use operator `@` to assign new diffrerent names `first_usage` and `second_usage` to each transform.\n",
    "\n",
    "-  In the case of `equal_gaus_features` we have to override the initial number of features as the `equal_gaus_features.yaml` which uses a special register to infer the feature dimension (the registed logic descrived in Step 3.) However by some reason we want to specify `num_features` parameter we can override it in the default file without the need to change the transform config file. \n",
    "\n",
    "```yaml\n",
    "defaults:\n",
    "  - data_manipulations@equal_gaus_features: equal_gaus_features\n",
    "  - data_manipulations@some_transform: some_transform\n",
    "  - liftings@_here_: ${get_required_lifting:graph,${model}}\n",
    "\n",
    "equal_gaus_features:\n",
    "  num_features: 100\n",
    "some_transform:\n",
    "  some_param: bla\n",
    "```\n",
    "\n",
    "- We recommend to always add `liftings@_here_: ${get_required_lifting:graph,${model}}` so that a default lifting is applied to run any domain-specific topological model."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Step 4.2: Custom Data Transformations ⚙️"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Creating a Transform\n",
    "\n",
    "In general any transfom in the library inherits `torch_geometric.transforms.BaseTransform` class, which allow to apply a sequency of transforms to the data. Our inderface requires to implement the `forward` method. The important part of all transforms is that it takes `torch_geometric.data.Data` object and returns updated `torch_geometric.data.Data` object.\n",
    "\n",
    "\n",
    "\n",
    "For language dataset,  we have generated the `equal_gaus_features` transfroms that is a data_manipulation transform hence we place it into `topobench/transforms/data_manipulation/` folder. \n",
    "Below you can see th `EqualGausFeatures` class: \n",
    "\n",
    "\n",
    "```python\n",
    "   class EqualGausFeatures(torch_geometric.transforms.BaseTransform):\n",
    "    r\"\"\"A transform that generates equal Gaussian features for all nodes.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    **kwargs : optional\n",
    "        Additional arguments for the class. It should contain the following keys:\n",
    "        - mean (float): The mean of the Gaussian distribution.\n",
    "        - std (float): The standard deviation of the Gaussian distribution.\n",
    "        - num_features (int): The number of features to generate.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, **kwargs):\n",
    "        super().__init__()\n",
    "        self.type = \"generate_non_informative_features\"\n",
    "\n",
    "        # Torch generate feature vector from gaus distribution\n",
    "        self.mean = kwargs[\"mean\"]\n",
    "        self.std = kwargs[\"std\"]\n",
    "        self.feature_vector = kwargs[\"num_features\"]\n",
    "        self.feature_vector = torch.normal(\n",
    "            mean=self.mean, std=self.std, size=(1, self.feature_vector)\n",
    "        )\n",
    "\n",
    "    def __repr__(self) -> str:\n",
    "        return f\"{self.__class__.__name__}(type={self.type!r}, mean={self.mean!r}, std={self.std!r}, feature_vector={self.feature_vector!r})\"\n",
    "\n",
    "    def forward(self, data: torch_geometric.data.Data):\n",
    "        r\"\"\"Apply the transform to the input data.\n",
    "\n",
    "        Parameters\n",
    "        ----------\n",
    "        data : torch_geometric.data.Data\n",
    "            The input data.\n",
    "\n",
    "        Returns\n",
    "        -------\n",
    "        torch_geometric.data.Data\n",
    "            The transformed data.\n",
    "        \"\"\"\n",
    "        data.x = self.feature_vector.expand(data.num_nodes, -1)\n",
    "        return data\n",
    "\n",
    "```\n",
    "\n",
    "As we said above the `forward` function takes as input the `torch_geometric.data.Data` object, modifies it, and returns it."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Register the Transform\n",
    "\n",
    "Similarly to adding dataset the transformations you have created and placed at right folder are automatically registered.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Create a configuration file \n",
    "Now as we have registered the transform we can finally create the configuration file and use it in the framework: \n",
    "\n",
    "``` yaml\n",
    "_target_: topobench.transforms.data_transform.DataTransform\n",
    "transform_name: \"EqualGausFeatures\"\n",
    "transform_type: \"data manipulation\"\n",
    "\n",
    "mean: 0\n",
    "std: 0.1\n",
    "num_features: ${dataset.parameters.num_features}\n",
    "``` \n",
    "Please refer to `configs/transforms/dataset_defaults/equal_gaus_features.yaml` for the example. \n",
    "\n",
    "**Notes:**\n",
    "\n",
    "- You might notice an interesting key `_target_` in the configuration file. In general for any new transform you the `_target_` is always `topobench.transforms.data_transform.DataTransform`.  [For more information please refer to hydra documentation \"Instantiating objects with Hydra\" section.](https://hydra.cc/docs/advanced/instantiate_objects/overview/). "
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "tb",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
